{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fb20dc5f110>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Model\n",
    "\n",
    "\n",
    "torch.manual_seed(12046)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = GPT2LMHeadModel.from_pretrained('gpt2')\n",
    "tokenizer = AutoTokenizer.from_pretrained('gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RewardModel(nn.Module):\n",
    "\n",
    "    def __init__(self, model):\n",
    "        '''\n",
    "        评分模型\n",
    "        参数\n",
    "        ----\n",
    "        model ：嵌入模型\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.embedding = model\n",
    "        # 评分建模头\n",
    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
    "\n",
    "    def forward(self, x, seq_len=None):\n",
    "        '''\n",
    "        向前传播\n",
    "        参数\n",
    "        ----\n",
    "        x ：torch.LongTensor，文本，形状为(B, T)或者(B, T, vs)，其中vs表示字典大小\n",
    "        seq_len ：torch.LongTensor，文本的实际长度，形状为(B)\n",
    "        返回\n",
    "        ----\n",
    "        score ：torch.FloatTensor，评分，形状为(B, 1)\n",
    "        '''\n",
    "        \n",
    "        B = x.shape[0]\n",
    "        T = x.shape[1]\n",
    "        # 文本的嵌入向量\n",
    "        emb = self.get_last_hidden_state(x)     # (B, T, C)\n",
    "        ind = torch.arange(B, device=x.device)\n",
    "        # 如果没有传入seq_len，则所有文本的实际长度都等于T\n",
    "        if seq_len == None:\n",
    "            seq_len = torch.tensor([T] * B)\n",
    "        # 获取最后一个词元的特征\n",
    "        pooled_emb = emb[ind, seq_len - 1]      # (B,    C)\n",
    "        score = self.score(pooled_emb)          # (B,    1)\n",
    "        return score\n",
    "    \n",
    "    def get_last_hidden_state(self, x):\n",
    "        '''\n",
    "        获取文本的嵌入向量\n",
    "        '''\n",
    "        # 普通情况下，x的形状为(B, T)\n",
    "        if len(x.shape) == 2:\n",
    "            emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
    "        # 如果使用了gumbel_softmax，则x的形状为(B, T, vs)\n",
    "        # 这种情况下，需要直接与embedding的模型参数进行计算\n",
    "        else:\n",
    "            w = self.embedding.get_input_embeddings().weight  # (vs, C)\n",
    "            inputs_embeds = x @ w  # (B, T, vs) @ (vs, C) --> (B, T, C)\n",
    "            emb = self.embedding(inputs_embeds=inputs_embeds).last_hidden_state\n",
    "        return emb\n",
    "\n",
    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0., grad_fn=<MaxBackward1>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 验证评分模型计算正确\n",
    "# x的形状是(B, T)，x_hot的形状是(B, T, vs)\n",
    "x = torch.randint(0, tokenizer.vocab_size, (3, 4))\n",
    "x_hot = F.one_hot(x, num_classes=tokenizer.vocab_size).float()\n",
    "(r_model(x) - r_model(x_hot)).abs().max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RLModel(nn.Module):\n",
    "    \n",
    "    def __init__(self, llm, r_model):\n",
    "        '''\n",
    "        大语言模型与评分模型的拼接（错误方式）\n",
    "        参数\n",
    "        ----\n",
    "        llm ：大语言模型\n",
    "        r_model ：评分模型\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.llm = llm\n",
    "        self.r_model = r_model\n",
    "        # 冻结模型\n",
    "        for param in r_model.parameters():\n",
    "            param.requires_grad = False\n",
    "    \n",
    "    def generate(self, idx, max_new_tokens):\n",
    "        '''\n",
    "        利用大语言模型生成文本（反复使用模型进行预测）\n",
    "        参数\n",
    "        ----\n",
    "        idx ：torch.LongTensor，背景文本，形状为(1, T)\n",
    "        max_new_tokens ：int，生成文本的最大长度\n",
    "        '''\n",
    "        model = self.llm\n",
    "        for _ in range(max_new_tokens):\n",
    "            logits = model(input_ids=idx).logits\n",
    "            logits = logits[:, -1, :]\n",
    "            probs = F.softmax(logits, dim=-1)\n",
    "            # 根据概率，随机生成下一个词元\n",
    "            idx_next = torch.multinomial(probs, num_samples=1)\n",
    "            idx = torch.cat((idx, idx_next), dim=1)\n",
    "        return idx\n",
    "    \n",
    "    def forward(self, idx):\n",
    "        '''\n",
    "        利用大语言模型生成文本，再使用评分模型对生成文本进行评分\n",
    "        参数\n",
    "        ----\n",
    "        idx ：torch.LongTensor，背景文本，形状为(1, T)\n",
    "        返回\n",
    "        ----\n",
    "        reward ：torch.FloatTensor，评分，形状为(1, 1)\n",
    "        '''\n",
    "        # 为了代码简洁，我们设置产生文本的长度\n",
    "        ans = self.generate(idx, 20)\n",
    "        # 对文本进行评分\n",
    "        reward = self.r_model(ans)\n",
    "        return reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = '1 + 2 = 3, 2 + 1 = 3, 1 + 2 ='\n",
    "ids = tokenizer(inputs, return_tensors='pt')\n",
    "model = RLModel(llm, r_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4, 3 + 1 = 5, 1 + 2 = 6 — Ha ha ha! In us\n"
     ]
    }
   ],
   "source": [
    "# 验证generate是正确的\n",
    "print(tokenizer.decode(model.generate(ids['input_ids'], 20)[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4 without action FARMADAM (same) Wooden child Servant use Intel SOCKS+\n"
     ]
    }
   ],
   "source": [
    "# 使用第三方库封装好的函数生成文本\n",
    "res = model.llm.generate(\n",
    "    input_ids=ids['input_ids'], max_new_tokens=20,\n",
    "    do_sample=True, top_k=0)[0]\n",
    "print(tokenizer.decode(res, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-9-b7dbb844b37d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;31m# 将报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
     ]
    }
   ],
   "source": [
    "loss = -1 * model(ids['input_ids'])\n",
    "# 将报错，因为torch.multinomial不可微\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 928.,  926., 1631.,  340., 6175.])\n",
      "tensor([ 996.,  865., 1616.,  314., 6209.])\n"
     ]
    }
   ],
   "source": [
    "# 验证gumbel_softmax可以近似torch.multinomial\n",
    "logits = torch.randn(1, 5)\n",
    "probs = F.softmax(logits, dim=-1)\n",
    "# 使用torch.multinomial生成结果\n",
    "y = torch.multinomial(probs, num_samples=10000, replacement=True)\n",
    "print(torch.histogram(y.float(), bins=5).hist)\n",
    "# 使用gumbel_softmax生成结果\n",
    "gumbel_y = torch.argmax(F.gumbel_softmax(logits.repeat(10000, 1), tau=1, hard=True), dim=-1, keepdim=True)\n",
    "print(torch.histogram(gumbel_y.float(), bins=5).hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RLModelWithGumbel(nn.Module):\n",
    "    \n",
    "    def __init__(self, llm, r_model):\n",
    "        '''\n",
    "        大语言模型与评分模型的拼接（没有明显错误的方式，但也不是合适的方式）\n",
    "        参数\n",
    "        ----\n",
    "        llm ：大语言模型\n",
    "        r_model ：评分模型\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.llm = llm\n",
    "        self.r_model = r_model\n",
    "        # 冻结模型\n",
    "        for param in r_model.parameters():\n",
    "            param.requires_grad = False\n",
    "    \n",
    "    def generate(self, idx, max_new_tokens):\n",
    "        '''\n",
    "        利用大语言模型生成文本（反复使用模型进行预测）\n",
    "        参数\n",
    "        ----\n",
    "        idx ：torch.LongTensor，背景文本，形状为(1, T)\n",
    "        max_new_tokens ：int，生成文本的最大长度\n",
    "        返回\n",
    "        ----\n",
    "        idx ：torch.LongTensor，背景文本 + 生成文本，形状为(1, T+L)，其中L是生成文本的长度\n",
    "        ans ：torch.LongTensor，生成文本，形状为(1, L, vs)，其中vs是字典的大小\n",
    "        '''\n",
    "        model = self.llm\n",
    "        ans = None\n",
    "        for _ in range(max_new_tokens):\n",
    "            logits = model(input_ids=idx).logits\n",
    "            logits = logits[:, -1, :]\n",
    "            # 根据概率，随机生成下一个词元\n",
    "            idx_next_hot = F.gumbel_softmax(logits, tau=1, hard=True)  # (1, vs)\n",
    "            # torch.argmax不可微，所以idx不可微\n",
    "            idx_next = torch.argmax(idx_next_hot, dim=-1, keepdim=True)\n",
    "            idx = torch.cat((idx, idx_next.long()), dim=1)\n",
    "            idx_next_hot = idx_next_hot.unsqueeze(1)  # (1, 1, vs)\n",
    "            if ans == None:\n",
    "                ans = idx_next_hot\n",
    "            else:\n",
    "                ans = torch.cat((ans, idx_next_hot), dim=1)\n",
    "        return idx, ans\n",
    "    \n",
    "    def forward(self, idx):\n",
    "        '''\n",
    "        利用大语言模型生成文本，再使用评分模型对生成文本进行评分\n",
    "        参数\n",
    "        ----\n",
    "        idx ：torch.LongTensor，背景文本，形状为(1, T)\n",
    "        返回\n",
    "        ----\n",
    "        reward ：torch.FloatTensor，评分，形状为(1, 1)\n",
    "        '''\n",
    "        # 为了代码简洁，我们设置产生文本的长度\n",
    "        _, ans = self.generate(idx, 20)\n",
    "        # 对生成的文本进行评分\n",
    "        reward = self.r_model(ans)\n",
    "        return reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_gumbel = RLModelWithGumbel(llm, r_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[True, True, True, True, True, True, True, True, True, True, True, True,\n",
      "         True, True, True, True, True, True, True, True]])\n",
      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 0, 1 + 1 = 0; extends laugh(cow, decision, discount) fifth person,\n"
     ]
    }
   ],
   "source": [
    "# 验证generate函数是否正确\n",
    "idx, ans = model_gumbel.generate(ids['input_ids'], 20)\n",
    "# 验证idx和ans的重叠部分是否相同\n",
    "print(idx[:, ids['input_ids'].shape[1]:] == torch.argmax(ans, dim=-1, keepdim=True).squeeze(-1))\n",
    "print(tokenizer.decode(idx[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.2085]]), tensor([[-0.2085]], grad_fn=<MmBackward0>))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 验证评分模型计算正确\n",
    "model_gumbel.r_model(idx[:, ids['input_ids'].shape[1]:]), model_gumbel.r_model(ans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2.3994e-06,  4.8380e-06,  3.5403e-06,  ...,  4.4225e-06,\n",
       "         -1.5709e-06,  4.8997e-06],\n",
       "        [ 4.4208e-05,  1.3246e-04,  1.4072e-05,  ...,  7.9197e-05,\n",
       "         -1.4321e-06, -6.9506e-06],\n",
       "        [ 7.8832e-06,  5.7550e-06, -1.3545e-07,  ...,  5.6032e-06,\n",
       "         -5.2948e-06,  1.6141e-06],\n",
       "        ...,\n",
       "        [ 6.0610e-10,  9.2871e-10,  3.8407e-10,  ...,  1.6127e-09,\n",
       "         -1.6454e-09, -8.2414e-10],\n",
       "        [-1.5970e-09,  4.7921e-09,  6.8945e-09,  ...,  7.0852e-09,\n",
       "         -7.1524e-09, -1.9468e-09],\n",
       "        [ 3.6735e-04,  2.7833e-04,  3.1601e-05,  ...,  1.5014e-05,\n",
       "          3.1863e-04, -2.6312e-04]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss = -1 * model_gumbel(ids['input_ids'])\n",
    "# 成功运行反向传播算法\n",
    "loss.backward()\n",
    "list(model_gumbel.llm.parameters())[0].grad"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
